{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "\n",
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 1234\n",
    "\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Calculated means: [0.49139968 0.48215841 0.44653091]\n",
      "Calculated stds: [0.24703223 0.24348513 0.26158784]\n"
     ]
    }
   ],
   "source": [
    "train_data = datasets.CIFAR10(root = 'data', \n",
    "                              train = True, \n",
    "                              download = True)\n",
    "\n",
    "means = train_data.data.mean(axis = (0,1,2)) / 255\n",
    "stds = train_data.data.std(axis = (0,1,2)) / 255\n",
    "\n",
    "print(f'Calculated means: {means}')\n",
    "print(f'Calculated stds: {stds}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transforms = transforms.Compose([\n",
    "                           transforms.RandomHorizontalFlip(),\n",
    "                           transforms.RandomRotation(10),\n",
    "                           transforms.RandomCrop(32, padding = 3),\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize(mean = means, \n",
    "                                                std = stds)\n",
    "                       ])\n",
    "\n",
    "test_transforms = transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize(mean = means, \n",
    "                                                std = stds)\n",
    "                       ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "train_data = datasets.CIFAR10('data', \n",
    "                              train = True, \n",
    "                              download = True, \n",
    "                              transform = train_transforms)\n",
    "\n",
    "test_data = datasets.CIFAR10('data', \n",
    "                             train = False, \n",
    "                             download = True, \n",
    "                             transform = test_transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_train_examples = int(len(train_data)*0.9)\n",
    "n_valid_examples = len(train_data) - n_train_examples\n",
    "\n",
    "train_data, valid_data = torch.utils.data.random_split(train_data, \n",
    "                                                       [n_train_examples, n_valid_examples])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of training examples: 45000\n",
      "Number of validation examples: 5000\n",
      "Number of testing examples: 10000\n"
     ]
    }
   ],
   "source": [
    "print(f'Number of training examples: {len(train_data)}')\n",
    "print(f'Number of validation examples: {len(valid_data)}')\n",
    "print(f'Number of testing examples: {len(test_data)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "train_iterator = torch.utils.data.DataLoader(train_data, \n",
    "                                             shuffle = True, \n",
    "                                             batch_size = BATCH_SIZE)\n",
    "\n",
    "valid_iterator = torch.utils.data.DataLoader(valid_data, \n",
    "                                             batch_size = BATCH_SIZE)\n",
    "\n",
    "test_iterator = torch.utils.data.DataLoader(test_data, \n",
    "                                            batch_size = BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VGGBlock(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, batch_norm):\n",
    "        super().__init__()\n",
    "        \n",
    "        modules = []\n",
    "        modules.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n",
    "        if batch_norm:\n",
    "            modules.append(nn.BatchNorm2d(out_channels))\n",
    "        modules.append(nn.ReLU(inplace=True))\n",
    "    \n",
    "        self.block = nn.Sequential(*modules)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.block(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VGG11(nn.Module):\n",
    "    def __init__(self, output_dim, block, pool, batch_norm):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.features = nn.Sequential(\n",
    "            block(3, 64, batch_norm), #in_channels, out_channels\n",
    "            pool(2, 2), #kernel_size, stride\n",
    "            block(64, 128, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(128, 256, batch_norm),\n",
    "            block(256, 256, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(256, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            pool(2, 2),\n",
    "        )\n",
    "        \n",
    "        self.classifier = nn.Linear(512, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "extra 64, 128, 256, 512 and 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VGG16(nn.Module):\n",
    "    def __init__(self, output_dim, block, pool, batch_norm):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.features = nn.Sequential(\n",
    "            block(3, 64, batch_norm),\n",
    "            block(64, 64, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(64, 128, batch_norm),\n",
    "            block(128, 128, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(128, 256, batch_norm),\n",
    "            block(256, 256, batch_norm),\n",
    "            block(256, 256, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(256, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            pool(2, 2),\n",
    "        )\n",
    "        \n",
    "        self.classifier = nn.Linear(512, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "extra 256, 512 and 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VGG19(nn.Module):\n",
    "    def __init__(self, output_dim, block, pool, batch_norm):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.features = nn.Sequential(\n",
    "            block(3, 64, batch_norm),\n",
    "            block(64, 64, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(64, 128, batch_norm),\n",
    "            block(128, 128, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(128, 256, batch_norm),\n",
    "            block(256, 256, batch_norm),\n",
    "            block(256, 256, batch_norm),\n",
    "            block(256, 256, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(256, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            pool(2, 2),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            block(512, 512, batch_norm),\n",
    "            pool(2, 2),\n",
    "        )\n",
    "        \n",
    "        self.classifier = nn.Linear(512, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "OUTPUT_DIM = 10\n",
    "BATCH_NORM = True\n",
    "\n",
    "vgg11_model = VGG11(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) \n",
    "vgg16_model = VGG16(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) \n",
    "vgg19_model = VGG19(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VGG11 has 9,231,114 trainable parameters\n",
      "VGG16 has 14,728,266 trainable parameters\n",
      "VGG19 has 20,040,522 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f'VGG11 has {count_parameters(vgg11_model):,} trainable parameters')\n",
    "print(f'VGG16 has {count_parameters(vgg16_model):,} trainable parameters')\n",
    "print(f'VGG19 has {count_parameters(vgg19_model):,} trainable parameters')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = VGG11(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VGG11(\n",
       "  (features): Sequential(\n",
       "    (0): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (2): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (4): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (5): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (7): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (8): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (10): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (11): VGGBlock(\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (12): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (classifier): Linear(in_features=512, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = criterion.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_accuracy(fx, y):\n",
    "    preds = fx.argmax(1, keepdim=True)\n",
    "    correct = preds.eq(y.view_as(preds)).sum()\n",
    "    acc = correct.float()/preds.shape[0]\n",
    "    return acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion, device):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for (x, y) in iterator:\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "                \n",
    "        fx = model(x)\n",
    "        \n",
    "        loss = criterion(fx, y)\n",
    "        \n",
    "        acc = calculate_accuracy(fx, y)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion, device):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for (x, y) in iterator:\n",
    "\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            fx = model(x)\n",
    "\n",
    "            loss = criterion(fx, y)\n",
    "\n",
    "            acc = calculate_accuracy(fx, y)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def epoch_time(start_time, end_time):\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 1.469 | Train Acc: 46.14%\n",
      "\t Val. Loss: 1.287 |  Val. Acc: 54.75%\n",
      "Epoch: 02 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 1.035 | Train Acc: 63.19%\n",
      "\t Val. Loss: 0.973 |  Val. Acc: 66.83%\n",
      "Epoch: 03 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.856 | Train Acc: 69.85%\n",
      "\t Val. Loss: 0.833 |  Val. Acc: 70.41%\n",
      "Epoch: 04 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.761 | Train Acc: 73.57%\n",
      "\t Val. Loss: 0.780 |  Val. Acc: 73.08%\n",
      "Epoch: 05 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.671 | Train Acc: 76.72%\n",
      "\t Val. Loss: 0.693 |  Val. Acc: 75.95%\n",
      "Epoch: 06 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.619 | Train Acc: 78.62%\n",
      "\t Val. Loss: 0.648 |  Val. Acc: 77.97%\n",
      "Epoch: 07 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.570 | Train Acc: 80.38%\n",
      "\t Val. Loss: 0.631 |  Val. Acc: 78.34%\n",
      "Epoch: 08 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.523 | Train Acc: 82.01%\n",
      "\t Val. Loss: 0.606 |  Val. Acc: 79.89%\n",
      "Epoch: 09 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.489 | Train Acc: 83.04%\n",
      "\t Val. Loss: 0.606 |  Val. Acc: 79.96%\n",
      "Epoch: 10 | Epoch Time: 0m 19s\n",
      "\tTrain Loss: 0.454 | Train Acc: 84.42%\n",
      "\t Val. Loss: 0.575 |  Val. Acc: 80.70%\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 10\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    \n",
    "    start_time = time.time()\n",
    "    \n",
    "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)\n",
    "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)\n",
    "    \n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'tut5-model.pt')\n",
    "    \n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    \n",
    "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.487 | Test Acc: 83.41%\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('tut5-model.pt'))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, test_iterator, criterion, device)\n",
    "\n",
    "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
